#!/usr/bin/python

import sys,os,re,optparse,shutil,glob,fcntl
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
import signal
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))
from matplotlib import patches
from pylab import *
from pylab import gray as gray_
from scipy.stats.stats import trim1
from multiprocessing import Pool
from scipy.ndimage import measurements,interpolation,filters,morphology
# from scipy.misc import imsave
import traceback
import argparse
import multiprocessing
from ocrolib import number_of_processors
import ocrolib
import tables
import pyflann
from ocrolib import improc
from ocrolib import sl
from ocrolib import Record
import ocrorast
from scipy.misc import imsave

# Command-line interface.  argparse substitutes the program name for
# "%(prog)s"; the optparse-style "%prog" would be printed literally.
parser = argparse.ArgumentParser(description = """
%(prog)s -o dir [options] image1 image2 ...
""")
# Positional arguments are the page image files that get preprocessed.
parser.add_argument("args",default=[],nargs='*',help="input page images")
# Directory that receives the numbered output files; it must not exist yet.
parser.add_argument("-o","--output",default="book")
# Number of worker processes used to process the pages.
parser.add_argument("-Q","--parallel",type=int,default=multiprocessing.cpu_count())
# Show intermediate results in a figure (also forces serial processing).
parser.add_argument("--debug",action="store_true")
# Skip skew estimation and rotation.
parser.add_argument("--noskew",action="store_true")
# Skip text/image separation and treat the whole binarized page as text.
parser.add_argument("--noimages",action="store_true")
args = parser.parse_args()

# Refuse to write into an existing location.  This is reported as a usage
# error rather than an assert so the check also runs under "python -O".
if os.path.exists(args.output):
    parser.error("output directory already exists: %s"%args.output)

def gsauvola(image,sigma=150.0,R=None,k=0.3,filter='uniform',scale=2.0):
    """Perform Sauvola-like binarization.

    The local mean and standard deviation are computed with linear filters
    on a copy of the image downsampled by `scale`, upsampled back to the
    input size, and combined into the per-pixel threshold
    ``avg * (1 + k * (stddev / R - 1))``.

    image  -- 2D grayscale or 3D color array; uint8 input is divided by
              256.0 and 3D input is averaged over its last axis
    sigma  -- second argument passed to the smoothing filter (applied to
              the downsampled image)
    R      -- normalizer for the standard deviation; defaults to the
              largest local standard deviation found in the image
    k      -- weight of the standard deviation term in the threshold
    filter -- "uniform", "gaussian", or a callable taking (array, sigma)
    scale  -- downsampling factor used while computing the statistics

    Returns a uint8 array, 255 where a pixel is above its local threshold
    and 0 elsewhere.  Raises ValueError if `filter` is neither one of the
    known names nor callable.
    """
    if image.dtype==dtype('uint8'): image = image / 256.0
    if len(image.shape)==3: image = mean(image,axis=2)
    if filter=="gaussian":
        smooth = filters.gaussian_filter
    elif filter=="uniform":
        smooth = filters.uniform_filter
    elif callable(filter):
        smooth = filter
    else:
        raise ValueError("unknown filter: %r"%(filter,))

    def upsample(small):
        # Zoom into a fresh array and copy the overlapping region.  Zooming
        # directly into a slice of the result fails whenever the upsampled
        # size is larger than the image (e.g. odd image dimensions).
        big = interpolation.zoom(small,scale,order=0,mode='nearest')
        full = zeros(image.shape)
        rows = min(big.shape[0],full.shape[0])
        cols = min(big.shape[1],full.shape[1])
        full[:rows,:cols] = big[:rows,:cols]
        return full

    scaled = interpolation.zoom(image,1.0/scale,order=0,mode='nearest')
    # Filtering an all-ones image yields the normalization at each pixel.
    norm = smooth(ones(scaled.shape),sigma)
    sx = smooth(scaled,sigma)
    sxx = smooth(scaled**2,sigma)
    avg_ = sx / norm
    # Clip at zero: rounding can make the variance slightly negative.
    stddev_ = maximum(sxx/norm - avg_**2,0.0)**0.5
    avg = upsample(avg_)
    stddev = upsample(stddev_)
    if R is None: R = amax(stddev)
    # A perfectly uniform image has zero deviation everywhere; avoid the
    # 0/0 that would turn the whole threshold into NaN.
    if R==0: R = 1.0
    thresh = avg * (1.0 + k * (stddev / R - 1.0))
    return array(255*(image>thresh),'uint8')

def csnormalize(image,f=0.75):
    """Center and rescale an image based on the moments of its foreground.

    Pixels above the midpoint between the image minimum and maximum count
    as foreground.  Their centroid is moved to the image center and the
    image is scaled isotropically by ``f*max(image.shape)/(4*l)``, where
    `l` is the square root of the largest eigenvalue of the foreground
    covariance matrix.  An image without foreground is returned unchanged.
    """
    mask = 1*(image>mean([amax(image),amin(image)]))
    rows,cols = mask.shape
    rr,cc = mgrid[0:rows,0:cols]
    mass = sum(mask)
    if mass<1e-4:
        return image
    inv = 1.0/mass
    # first moments: centroid of the foreground
    cx = sum(rr*mask)*inv
    cy = sum(cc*mask)*inv
    # second central moments
    dr = rr-cx
    dc = cc-cy
    sxx = sum(dr**2*mask)*inv
    sxy = sum(dr*dc*mask)*inv
    syy = sum(dc**2*mask)*inv
    evals,_ = eigh(array([[sxx,sxy],[sxy,syy]]))
    spread = sqrt(amax(evals))
    factor = f*max(image.shape)/(4.0*spread)
    # affine_transform maps output coordinates to input coordinates, so the
    # matrix holds the inverse of the desired scale factor.
    matrix = array([[1.0/factor,0],[0.0,1.0/factor]])
    center = array([rows/2,cols/2])
    offset = array([cx,cy])-dot(matrix,center)
    return interpolation.affine_transform(image,matrix,offset=offset,order=1)

def classifier_normalize(image):
    """Bring a character image into the form used for prototype matching:
    scale intensities by the image maximum, convert to single-precision
    floats, rescale with improc.isotropic_rescale (target size 32), and
    center/scale the result with csnormalize."""
    peak = amax(image)
    # the small constant keeps the division defined for all-zero images
    scaled = array(image*1.0/(1e-6+peak),'f')
    resized = improc.isotropic_rescale(scaled,32)
    return csnormalize(resized)

ppath = "full-300-estats.h5"
with tables.openFile(ppath) as pdb:
    n = len(pdb.root.classes)
    nprotos = n
    protos = array(pdb.root.patches[:n],'f')
    pclasses = array(pdb.root.classes[:n],'i')
    e_means = array(pdb.root.e_means[:n],'f')
    e_sigmas = array(pdb.root.e_sigmas[:n],'f')
    e_counts = array(pdb.root.e_counts[:n,0],'f')
print "got",nprotos,"prototypes"

class NNIndex(pyflann.FLANN):
    """FLANN index with preset build parameters that can be used in a
    `with` statement; the index is deleted when the block exits."""

    def __init__(self,*args,**kw):
        # Fill in defaults only for parameters the caller did not supply.
        for key,value in (("algorithm","kmeans"),
                          ("branching",20),
                          ("target_precision",0.995)):
            kw.setdefault(key,value)
        pyflann.FLANN.__init__(self,*args,**kw)

    def __enter__(self):
        return self

    def __exit__(self,*args):
        self.delete_index()

print "building index (FIXME)"
nnprotos = NNIndex()
nnprotos.build_index(protos.reshape(len(protos),prod(protos.shape[1:])))
print "done"

def bounding_boxes_math(image):
    """Return the bounding boxes of the connected components of an image
    as (x0,y0,x1,y1) tuples in mathematical coordinates, i.e. with the y
    axis pointing up from the bottom row.  The image is thresholded at the
    midpoint between its minimum and maximum before labeling."""
    midpoint = mean([amax(image),amin(image)])
    labels,_ = measurements.label(image>midpoint)
    height,_ = labels.shape
    # Array rows count downward, so flip them against the image height.
    return [(cols.start,height-rows.stop,cols.stop,height-rows.start)
            for rows,cols in measurements.find_objects(labels)]

def select_plausible_char_bboxes(bboxes,dpi=300.0):
    """Performs simple heuristic checks on character bounding boxes;
    removes boxes that are too small or too large, or have the wrong
    aspect ratio.

    bboxes -- iterable of (x0,y0,x1,y1) boxes
    dpi    -- resolution of the page; the size limits (5 and 100 pixels
              at 300 dpi) scale linearly with it

    Returns the list of boxes whose width and height lie within the size
    limits and whose width/height ratio lies between 0.25 and 4."""
    s = dpi/300.0
    result = []
    for b in bboxes:
        x0,y0,x1,y1 = b
        w = x1-x0
        h = y1-y0
        if w<s*5 or h<s*5: continue
        if w>s*100 or h>s*100: continue
        # The aspect ratio is dimensionless, so its bounds must not depend
        # on the resolution; 4 is the counterpart of the lower bound 0.25.
        a = w*1.0/h
        if a>4 or a<0.25: continue
        result.append(b)
    return result

def is_binary(image):
    """Return a true value when more than 99% of the pixels are equal to
    either the minimum or the maximum value of the image."""
    lowest = amin(image)
    highest = amax(image)
    extremes = sum(image==lowest)+sum(image==highest)
    return extremes > 0.99*image.size

def estimate_skew_angle(image):
    """Estimate the skew angle of a document image, by first finding
    character bounding boxes, then invoking the RAST text line finder
    (without constraints) in order to find the longest line.

    image -- binary page image (checked with is_binary)

    Returns the arctangent of the value reported by the finder's
    getLine_m(0) (presumably the slope of the best line), in radians, or
    0.0 when there are no plausible character boxes or no line is found."""
    # math is not imported at file level; import it here instead of relying
    # on the name leaking in through a star import.
    import math
    assert is_binary(image)
    bboxes = select_plausible_char_bboxes(bounding_boxes_math(image))
    if not bboxes:
        return 0.0
    finder = ocrorast.make_TextLineRAST2()
    finder.setMaxLines(1)
    for c in bboxes:
        finder.addChar(*c)
    finder.compute()
    if finder.nlines()<1: return 0.0
    return math.atan(finder.getLine_m(0))



def remove_small_components(image,r=3):
    """Delete every connected component whose bounding box is at most r
    pixels in both height and width; return the remaining components as a
    boolean mask."""
    labels,_ = measurements.label(image)
    for index,box in enumerate(measurements.find_objects(labels)):
        tall = box[0].stop-box[0].start>r
        wide = box[1].stop-box[1].start>r
        if tall or wide: continue
        # Clear only this component's pixels; the bounding box may also
        # contain pixels that belong to other components.
        region = labels[box]
        region[region==index+1] = 0
    return (labels!=0)

def classify(img):
    """Classify a character image by its nearest prototype.

    Returns (cls,d,w,n): the class as a one-character string ("~" for
    class codes of 65536 and above), the distance to the nearest prototype
    after subtracting e_means[n] and dividing by max(1,e_sigmas[n]), the
    weight log(e_counts[n]+1), and the prototype index n."""
    vector = classifier_normalize(img).ravel()
    # Two neighbors are requested but only the nearest one is used.
    ns,ds = nnprotos.nn_index(array([vector],'f'),2)
    n = ns.ravel()[0]
    nearest = ds.ravel()[0]
    d = (nearest-e_means[n])/max(1.0,e_sigmas[n])
    code = pclasses[n]
    cls = "~" if code>=65536 else unichr(code)
    w = log(e_counts[n]+1)
    return cls,d,w,n

def try_threshold(raw,k):
    """Binarize a page with parameter k and score the result by how many
    connected components look like characters.

    The page is binarized with gsauvola and inverted so that pixels below
    the midpoint become 255.  Components whose sl.width/sl.height fall
    outside 10..200 are copied to the `images` layer without being
    classified.  The remaining components are classified: alphanumeric
    results with normalized distance below 4 and weight above 1 are kept
    as good characters, anything else that is not a letter, digit or common
    punctuation mark is copied to the `images` layer.

    Returns a Record with
      total  -- sum of the weights of the accepted characters
      good   -- Records (i,s,cls) of accepted characters whose height is
                between half and three times the median height
      k      -- the parameter that was tried
      image  -- the inverted binarization (uint8, 0/255)
      chars  -- layer containing only the pixels of the good characters
      images -- layer containing the pixels of the non-text components
      meds   -- (median width, median height) of the accepted characters,
                (0.0, 0.0) when none were accepted"""
    alnum = re.compile(r'^[A-Za-z0-9]$')
    # "-" is placed last so that it is a literal hyphen; written as "()-:"
    # it formed a character range that also matched "*", "+" and "/".
    textlike = re.compile(r'^[A-Za-z0-9.,"\'():;?-]$')

    image = gsauvola(raw,k=k)
    image = array(255*(image<0.5*(amax(image)+amin(image))),'B')
    labels,n = measurements.label(image)
    objects = measurements.find_objects(labels)

    def paint(layer,i,s):
        # Copy only the pixels of component i.  Assigning the whole
        # bounding box would erase overlapping components painted earlier.
        mask = (labels[s]==i+1)
        layer[s][mask] = image[s][mask]

    good = []
    total = 0.0
    images = zeros(image.shape)
    for i,s in enumerate(objects):
        img = image[s]
        if sl.width(s)<10 or sl.width(s)>200 or sl.height(s)<10 or sl.height(s)>200:
            paint(images,i,s)
            continue
        cls,d,w,n = classify(img)
        if d<4.0 and alnum.search(cls) and w>1.0:
            good.append(Record(i=i,s=s,cls=cls))
            total += w
        elif not textlike.search(cls):
            paint(images,i,s)

    if good:
        mh = median([sl.height(g.s) for g in good])
        mw = median([sl.width(g.s) for g in good])
    else:
        # No character-like components at all (blank or picture-only
        # page); the median of an empty list is undefined.
        mh = mw = 0.0
    good = [g for g in good if sl.height(g.s)>mh/2 and sl.height(g.s)<mh*3]
    chars = zeros(image.shape)
    for g in good:
        paint(chars,g.i,g.s)
    return Record(total=total,
                  good=good,
                  k=k,
                  image=image,
                  chars=chars,
                  images=images,
                  meds=(mw,mh))

def smart_preproc(raw):
    """Binarize a page and separate it into a text and an image layer.

    Three values of the binarization parameter k between 0.05 and 0.15 are
    tried with try_threshold and the one with the highest character score
    is kept (ties go to the later value).  Unless --noimages is given,
    smoothed masks of the character and non-text layers decide for every
    pixel whether it belongs to text or to an image.  With --debug the
    intermediate layers are shown in the current figure.

    Returns a Record with
      k         -- the selected binarization parameter
      mw, mh    -- median width and height of the recognized characters
      good      -- the recognized characters
      binarized -- the full binarized page
      images    -- the part of the binarized page assigned to images
      chars     -- the part of the binarized page assigned to text"""
    best = None
    # The sample count must be an integer; a float count is rejected by
    # current NumPy versions.
    for k in linspace(0.05,0.15,3):
        out = try_threshold(raw,k)
        if best is not None and out.total<best.total: continue
        best = out

    mw,mh = best.meds
    if args.debug:
        clf(); ion(); gray()
        subplot(221); imshow(best.chars)
        xlabel("chars")
        for s in best.good:
            text(s.s[1].start,s.s[0].stop,s.cls,color="red")
        subplot(222); imshow(best.images)
        xlabel("non-chars")
        ginput(1,0.01)

    if args.noimages:
        chars = best.image
        images = zeros(chars.shape,'B')
        # Everything counts as text; the debug plots below need the mask.
        select = ones(chars.shape,dtype=bool)
    else:
        # Smooth both layers over a few character sizes and assign each
        # pixel to text wherever the text mask dominates.
        tmask = filters.gaussian_filter(best.chars/255.0,(4*mh,6*mw),mode='constant')
        imask = filters.gaussian_filter(best.images/255.0,(3*mh,3*mw),mode='constant')
        select = (2*tmask>=imask)
        chars = best.image*select
        images = best.image*(1-select)

    if args.debug:
        subplot(223); imshow(0.8*images+0.2*(255*(1-select)))
        xlabel("images")
        ginput(1,0.01)
        subplot(224); imshow(0.8*chars+0.2*(255*select))
        xlabel("text")
        ginput(1,0.01)

    return Record(k=best.k,
                  mw = mw,
                  mh = mh,
                  good=best.good,
                  binarized=best.image,
                  images=images,
                  chars=chars)


# In debug mode, create the figure that the debug plots draw into.
if args.debug:
    figure(dpi=100,figsize=(12,9.5))

def dread(fname):
    """Load an image file with imread.  Images with a .jpg/.jpeg extension
    (case-insensitive) are flipped vertically afterwards, because imread
    returns JPEG images upside down."""
    extension = os.path.splitext(fname)[1].lower()
    image = imread(fname)
    if extension in (".jpg",".jpeg"):
        # reverse the row order, leaving any remaining axes untouched
        image = image[::-1,:,...]
    return image

def process1(job):
    try:
        fname,output = job
        print "===",fname,output
        raw = dread(fname)
        p = smart_preproc(raw)
        p.raw = raw
        print p.k,p.mw,p.mh,len(p.good)
        if not args.noskew:
            a = estimate_skew_angle(p.chars)
            p.chars = interpolation.rotate(p.chars,-a*180/pi,mode='nearest',order=0)
            p.images = interpolation.rotate(p.images,-a*180/pi,mode='nearest',order=0)
            p.binarized = interpolation.rotate(p.binarized,-a*180/pi,mode='nearest',order=0)
            p.raw = interpolation.rotate(p.raw,-a*180/pi,mode='nearest',order=0)
            if args.debug: 
                cla(); xlabel("skew correction by %f"%a)
                imshow(p.chars); ginput(1,0.001)
        bin_file = ocrolib.fvariant(output,"png","bin")
        imsave(bin_file,amax(p.chars)-p.chars)
        fbin_file = ocrolib.fvariant(output,"png","fbin")
        imsave(fbin_file,amax(p.binarized)-p.binarized)
        img_file = ocrolib.fvariant(output,"png","img")
        imsave(img_file,amax(p.images)-p.images)
        gray_file = ocrolib.fvariant(output,"png")
        imsave(gray_file,p.raw)
        print "wrote",bin_file
        return None
    except:
        traceback.print_exc()
        return None

# Create the output directory and process all input pages, either one
# after another or in a pool of worker processes.
os.mkdir(args.output)
# Debug mode draws into a figure, so it always runs serially.
if args.debug: args.parallel = 1

# One job per input: (input file, numbered output base name).
jobs = [(arg,args.output+"/%04d.png"%i) for i,arg in enumerate(args.args)]
if args.parallel>1:
    pool = Pool(processes=args.parallel)
    result = pool.map(process1,jobs)
else:
    for job in jobs:
        process1(job)
